import math
import random

import torch
from torch import nn
from torch.nn import functional as F

from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
	def __init__(self):
		super().__init__()

	def forward(self, input):
		return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
	k = torch.tensor(k, dtype=torch.float32)

	if k.ndim == 1:
		k = k[None, :] * k[:, None]

	k /= k.sum()

	return k


class Upsample(nn.Module):
	def __init__(self, kernel, factor=2):
		super().__init__()

		self.factor = factor
		kernel = make_kernel(kernel) * (factor ** 2)
		self.register_buffer('kernel', kernel)

		p = kernel.shape[0] - factor

		pad0 = (p + 1) // 2 + factor - 1
		pad1 = p // 2

		self.pad = (pad0, pad1)

	def forward(self, input):
		out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

		return out


class Downsample(nn.Module):
	def __init__(self, kernel, factor=2):
		super().__init__()

		self.factor = factor
		kernel = make_kernel(kernel)
		self.register_buffer('kernel', kernel)

		p = kernel.shape[0] - factor

		pad0 = (p + 1) // 2
		pad1 = p // 2

		self.pad = (pad0, pad1)

	def forward(self, input):
		out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

		return out


class Blur(nn.Module):
	def __init__(self, kernel, pad, upsample_factor=1):
		super().__init__()

		kernel = make_kernel(kernel)

		if upsample_factor > 1:
			kernel = kernel * (upsample_factor ** 2)

		self.register_buffer('kernel', kernel)

		self.pad = pad

	def forward(self, input):
		out = upfirdn2d(input, self.kernel, pad=self.pad)

		return out


class EqualConv2d(nn.Module):
	def __init__(
		self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
	):
		super().__init__()

		self.weight = nn.Parameter(
			torch.randn(out_channel, in_channel, kernel_size, kernel_size)
		)
		self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

		self.stride = stride
		self.padding = padding

		if bias:
			self.bias = nn.Parameter(torch.zeros(out_channel))

		else:
			self.bias = None

	def forward(self, input):
		out = F.conv2d(
			input,
			self.weight * self.scale,
			bias=self.bias,
			stride=self.stride,
			padding=self.padding,
		)

		return out

	def __repr__(self):
		return (
			'{}({}, {},'.format(self.__class__.__name__, self.weight.shape[1], self.weight.shape[0]) +
			' {}, stride={}, padding={})'.format(self.weight.shape[2], self.stride, self.padding)
		)


class EqualLinear(nn.Module):
	def __init__(
		self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
	):
		super().__init__()

		self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

		if bias:
			self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

		else:
			self.bias = None

		self.activation = activation

		self.scale = (1 / math.sqrt(in_dim)) * lr_mul
		self.lr_mul = lr_mul

	def forward(self, input):
		
		if self.activation:
			out = F.linear(input, self.weight * self.scale)
			out = fused_leaky_relu(out, self.bias * self.lr_mul)
			
		else:
			out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)

		return out

	def __repr__(self):
		return (
			'{}({}, {})'.format(self.__class__.__name__, self.weight.shape[1], self.weight.shape[0])
		)


class ScaledLeakyReLU(nn.Module):
	def __init__(self, negative_slope=0.2):
		super().__init__()

		self.negative_slope = negative_slope

	def forward(self, input):
		out = F.leaky_relu(input, negative_slope=self.negative_slope)

		return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
	def __init__(
		self,
		in_channel,
		out_channel,
		kernel_size,
		style_dim,
		demodulate=True,
		upsample=False,
		downsample=False,
		blur_kernel=[1, 3, 3, 1],
	):
		super().__init__()

		self.eps = 1e-8
		self.kernel_size = kernel_size
		self.in_channel = in_channel
		self.out_channel = out_channel
		self.upsample = upsample
		self.downsample = downsample

		if upsample:
			factor = 2
			p = (len(blur_kernel) - factor) - (kernel_size - 1)
			pad0 = (p + 1) // 2 + factor - 1
			pad1 = p // 2 + 1

			self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

		if downsample:
			factor = 2
			p = (len(blur_kernel) - factor) + (kernel_size - 1)
			pad0 = (p + 1) // 2
			pad1 = p // 2

			self.blur = Blur(blur_kernel, pad=(pad0, pad1))

		fan_in = in_channel * kernel_size ** 2
		self.scale = 1 / math.sqrt(fan_in)
		self.padding = kernel_size // 2

		self.weight = nn.Parameter(
			torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
		)

		self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

		self.demodulate = demodulate

	def __repr__(self):
		return (
			'{}({}, {}, {}, '.format(self.__class__.__name__, self.in_channel, self.out_channel, self.kernel_size) +
			'upsample={}, downsample={})'.format(self.upsample, self.downsample)
		)

	def forward(self, input, style, weights_delta=None, offsets = None):
		batch, in_channel, height, width = input.shape

		style = self.modulation(style).view(batch, 1, in_channel, 1, 1)

		if offsets is not None: # I want to try the same with hyperdomain but before the demod action
			self.shape = self.weight.shape
			self.ones = torch.ones(self.shape).cuda()
			mult = self.ones + offsets 
			weight =self.scale * (self.weight * mult * style)
		else:	
			if weights_delta is None:
				weight = self.scale * self.weight * style
			else:
				# print(style.shape, self.weight.shape, weights_delta.shape)
				weight = self.scale * (self.weight * (1 + weights_delta) * style)
		
		

		if self.demodulate:
			demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
			weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

		###### Code taken from hyperdomain ######
		# if offsets is not None:
		#     print(offsets.shape)
		#     self.shape = weight.shape
		#     self.ones = torch.ones(self.shape).cuda()
		#     mult = self.ones + offsets # ?? offsets shape? B X 1 X 512 x 1 x 1 (?)
		#     weight = weight * mult
		##########################################

		weight = weight.view(
			batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
		)

		if self.upsample:
			input = input.view(1, batch * in_channel, height, width)
			weight = weight.view(
				batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
			)
			weight = weight.transpose(1, 2).reshape(
				batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
			)
			out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
			_, _, height, width = out.shape
			out = out.view(batch, self.out_channel, height, width)
			out = self.blur(out)

		elif self.downsample:
			input = self.blur(input)
			_, _, height, width = input.shape
			input = input.view(1, batch * in_channel, height, width)
			out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
			_, _, height, width = out.shape
			out = out.view(batch, self.out_channel, height, width)

		else:
			input = input.view(1, batch * in_channel, height, width)
			out = F.conv2d(input, weight, padding=self.padding, groups=batch)
			_, _, height, width = out.shape
			out = out.view(batch, self.out_channel, height, width)

		return out


class NoiseInjection(nn.Module):
	def __init__(self):
		super().__init__()

		self.weight = nn.Parameter(torch.zeros(1))

	def forward(self, image, noise=None):
		if noise is None:
			batch, _, height, width = image.shape
			noise = image.new_empty(batch, 1, height, width).normal_()

		return image + self.weight * noise


class ConstantInput(nn.Module):
	def __init__(self, channel, size=4):
		super().__init__()

		self.input = nn.Parameter(torch.randn(1, channel, size, size))

	def forward(self, input):
		batch = input.shape[0]
		out = self.input.repeat(batch, 1, 1, 1)

		return out


class StyledConv(nn.Module):
	def __init__(
		self,
		in_channel,
		out_channel,
		kernel_size,
		style_dim,
		upsample=False,
		blur_kernel=[1, 3, 3, 1],
		demodulate=True,
	):
		super().__init__()

		self.conv = ModulatedConv2d(
			in_channel,
			out_channel,
			kernel_size,
			style_dim,
			upsample=upsample,
			blur_kernel=blur_kernel,
			demodulate=demodulate,
		)

		self.noise = NoiseInjection()
		# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
		# self.activate = ScaledLeakyReLU(0.2)
		self.activate = FusedLeakyReLU(out_channel)

	def forward(self, input, style, noise=None, weights_delta=None, offsets = None):
		out = self.conv(input, style, weights_delta=weights_delta, offsets = offsets)
		out = self.noise(out, noise=noise)
		out = self.activate(out)

		return out


class ToRGB(nn.Module):
	def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
		super().__init__()

		if upsample:
			self.upsample = Upsample(blur_kernel)

		self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
		self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

	def forward(self, input, style, skip=None, weights_delta=None):
		out = self.conv(input, style, weights_delta)
		out = out + self.bias

		if skip is not None:
			skip = self.upsample(skip)

			out = out + skip

		return out

class Generator(nn.Module):
	def __init__(
		self,
		size,
		style_dim,
		n_mlp,
		channel_multiplier=2,
		blur_kernel=[1, 3, 3, 1],
		lr_mlp=0.01,
	):
		
		super().__init__()

		self.size = size

		self.style_dim = style_dim

		layers = [PixelNorm()]

		for i in range(n_mlp):
			layers.append(
				EqualLinear(
					style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
				)
			)
		
		self.style = nn.Sequential(*layers)

		self.channels = {
			4: 512,
			8: 512,
			16: 512,
			32: 512,
			64: 256 * channel_multiplier,
			128: 128 * channel_multiplier,
			256: 64 * channel_multiplier,
			512: 32 * channel_multiplier,
			1024: 16 * channel_multiplier,
		}

		self.input = ConstantInput(self.channels[4])
		self.conv1 = StyledConv(
			self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
		)
		self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

		self.log_size = int(math.log(size, 2))
		self.num_layers = (self.log_size - 2) * 2 + 1
		

		self.convs = nn.ModuleList()
		self.upsamples = nn.ModuleList()
		self.to_rgbs = nn.ModuleList()
		self.noises = nn.Module()

		in_channel = self.channels[4]

		for layer_idx in range(self.num_layers):
			res = (layer_idx + 5) // 2
			shape = [1, 1, 2 ** res, 2 ** res]
			self.noises.register_buffer('noise_{}'.format(layer_idx), torch.randn(*shape))

		for i in range(3, self.log_size + 1):
			out_channel = self.channels[2 ** i]

			self.convs.append(
				StyledConv(
					in_channel,
					out_channel,
					3,
					style_dim,
					upsample=True,
					blur_kernel=blur_kernel,
				)
			)

			self.convs.append(
				StyledConv(
					out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
				)
			)

			self.to_rgbs.append(ToRGB(out_channel, style_dim))

			in_channel = out_channel

		self.n_latent = self.log_size * 2 - 2

	def make_noise(self):
		device = self.input.input.device

		noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

		for i in range(3, self.log_size + 1):
			for _ in range(2):
				noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

		return noises

	def mean_latent(self, n_latent):
		latent_in = torch.randn(
			n_latent, self.style_dim, device=self.input.input.device
		)
		latent = self.style(latent_in).mean(0, keepdim=True)

		return latent

	def get_latent(self, input, truncation=1, truncation_latent=None):
		w = self.style(input)
		if truncation_latent is not None:
			style_t = []
			for style in w:
				style_t.append(
					truncation_latent + truncation * (style - truncation_latent)
				)
			w = style_t
		return w

	def forward(
		self,
		styles,
		conditions = None,
		return_latents=False,
		return_features=False,
		inject_index=None,
		truncation=1,
		truncation_latent=None,
		input_is_latent=False,
		noise=None,
		randomize_noise=False,
		weights_deltas=None,
		offsets = None
	):  
		
		## Hyperstyle ##
		total_convs = len(self.convs) + len(self.to_rgbs) + 2   # +2 for first conv and toRGB
		if weights_deltas is None:
			weights_deltas = [None] * total_convs
		################

		if not input_is_latent:
			styles = [self.style(s) for s in styles]

		
		if noise is None:
			if randomize_noise:
				noise = [None] * self.num_layers
			else:
				noise = [getattr(self.noises, 'noise_{}'.format(i)) for i in range(self.num_layers)]

		if truncation < 1:
			style_t = []
			for style in styles:
				style_t.append(
					truncation_latent + truncation * (style - truncation_latent)
				)
			styles = style_t
		
		if len(styles) < 2:
			inject_index = self.n_latent
			if styles[0].ndim < 3:
				latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
				
			else:
				latent = styles[0]
		
		else:
			if inject_index is None:
				inject_index = random.randint(1, self.n_latent - 1)

			latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
			latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

			latent = torch.cat([latent, latent2], 1)

		out = self.input(latent)
		# out = self.conv1(out, latent[:, 0], noise=noise[0], weights_delta=weights_deltas[0])
		out = self.conv1(
			out, latent[:, 0],
			offsets = offsets['conv_0'] if offsets is not None else None,
			noise = noise[0], weights_delta=weights_deltas[0])
		
		skip = self.to_rgb1(out, latent[:, 1], weights_delta=weights_deltas[1])
		
		# print(self.conv1.conv.weight.data.shape)
		# print(self.to_rgb1.conv.weight.data.shape)
		# print(out.shape)
		# print(skip.shape)
		i = 1
		if conditions is None:
			count = 1
			weight_idx = 2 # added for hyperstyle
			for conv1, conv2, noise1, noise2, to_rgb in zip(
				self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
			):
				# out = conv1(out, latent[:, i], noise=noise1)
				# out = conv2(out, latent[:, i + 1], noise=noise2)
				# skip = to_rgb(out, latent[:, i + 2], skip)

				# out = conv1(out, latent[:, i], noise=noise1, weights_delta=weights_deltas[weight_idx])
				# out = conv2(out, latent[:, i + 1], noise=noise2, weights_delta=weights_deltas[weight_idx + 1])
				out = conv1(
					out, latent[:, i],
					offsets = offsets['conv_{}'.format(i)] if offsets is not None else None,
					noise=noise1 , weights_delta=weights_deltas[weight_idx]
				)
				# print('conv1', out.shape)

				out = conv2(
					out, latent[:, i + 1],
					offsets = offsets['conv_{}'.format(i + 1)] if offsets is not None else None,
					noise=noise2, weights_delta=weights_deltas[weight_idx + 1]
				)
				skip = to_rgb(out, latent[:, i + 2], skip, weights_delta=weights_deltas[weight_idx + 2])

				# print('conv1', out.shape)
				# print('to_rgb', skip.shape)

				# print('conv1', conv1.conv.weight.data.shape)
				# print('conv2', conv2.conv.weight.data.shape)
				# print('to_rgb', to_rgb.conv.weight.data.shape)

				i += 2
				weight_idx += 3 # added for hyperstyle
		else:
			for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs):
				out = conv1(out, latent[:, i], noise=noise1)
				if i==7 and conditions is not None:
					#11 for 256, 9 for 128， 7 for 64
					out  =  out  * (1 + conditions[0]) + conditions[1]
				out = conv2(out, latent[:, i + 1], noise=noise2)
				skip = to_rgb(out, latent[:, i + 2], skip)
				i += 2

		image = skip
		
		if return_latents:
			return image, latent
		elif return_features:  # hyperstyle
			return image, out
		else:
			return image, None


class ConvLayer(nn.Sequential):
	def __init__(
		self,
		in_channel,
		out_channel,
		kernel_size,
		downsample=False,
		blur_kernel=[1, 3, 3, 1],
		bias=True,
		activate=True,
	):
		layers = []

		if downsample:
			factor = 2
			p = (len(blur_kernel) - factor) + (kernel_size - 1)
			pad0 = (p + 1) // 2
			pad1 = p // 2

			layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

			stride = 2
			self.padding = 0

		else:
			stride = 1
			self.padding = kernel_size // 2

		layers.append(
			EqualConv2d(
				in_channel,
				out_channel,
				kernel_size,
				padding=self.padding,
				stride=stride,
				bias=bias and not activate,
			)
		)

		if activate:
			if bias:
				layers.append(FusedLeakyReLU(out_channel))

			else:
				layers.append(ScaledLeakyReLU(0.2))

		super().__init__(*layers)


class ResBlock(nn.Module):
	def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
		super().__init__()

		self.conv1 = ConvLayer(in_channel, in_channel, 3)
		self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

		self.skip = ConvLayer(
			in_channel, out_channel, 1, downsample=True, activate=False, bias=False
		)

	def forward(self, input):
		out = self.conv1(input)
		out = self.conv2(out)

		skip = self.skip(input)
		out = (out + skip) / math.sqrt(2)

		return out


class Discriminator(nn.Module):
	def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
		super().__init__()

		channels = {
			4: 512,
			8: 512,
			16: 512,
			32: 512,
			64: 256 * channel_multiplier,
			128: 128 * channel_multiplier,
			256: 64 * channel_multiplier,
			512: 32 * channel_multiplier,
			1024: 16 * channel_multiplier,
		}

		convs = [ConvLayer(3, channels[size], 1)]

		log_size = int(math.log(size, 2))

		in_channel = channels[size]

		for i in range(log_size, 2, -1):
			out_channel = channels[2 ** (i - 1)]

			convs.append(ResBlock(in_channel, out_channel, blur_kernel))

			in_channel = out_channel

		self.convs = nn.Sequential(*convs)

		self.stddev_group = 4
		self.stddev_feat = 1

		self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
		self.final_linear = nn.Sequential(
			EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
			EqualLinear(channels[4], 1),
		)

		self.log_size = log_size

	def forward(self, input):        
		out = self.convs(input)
		batch, channel, height, width = out.shape
		# group = min(batch, self.stddev_group)
		group = max(batch, self.stddev_group)
		# print(out.shape)
		# print(group, batch, self.stddev_group, channel // self.stddev_feat)
		
		stddev = out.view(
			batch, -1, self.stddev_feat, channel // self.stddev_feat, height, width
		)
		stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
		stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
		stddev = stddev.repeat(group, 1, height, width)

		out = torch.cat([out, stddev], 1)
		
		out = self.final_conv(out)

		out = out.view(batch, -1)
		out = self.final_linear(out)

		return out

class Encoder(nn.Module):
	def __init__(self, size, w_dim=512):
		super().__init__()
		
		channels = {
			4: 512,
			8: 512,
			16: 512,
			32: 512,
			64: 256,
			128: 128,
			256: 64,
			512: 32,
			1024: 16
		}        
		
		self.w_dim = w_dim
		log_size = int(math.log(size, 2))
		
		self.n_latents = log_size*2 - 2
		
		convs = [ConvLayer(3, channels[size], 1)]

		in_channel = channels[size]
		for i in range(log_size, 2, -1):
			out_channel = channels[2 ** (i - 1)]
			convs.append(ResBlock(in_channel, out_channel))
			in_channel = out_channel

		convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))    

		self.convs = nn.Sequential(*convs)

	def forward(self, input):
		out = self.convs(input)
		# print('Encoder weights: {:.5f}'.format( torch.sum(torch.abs( self.convs[0][0].weight.data )) ))

		return out.view(len(input), self.n_latents, self.w_dim)